Exploratory Data Analysis¶
We are going to work with the Heart Attack dataset. The goal is to predict under which scenario a patient is most likely to suffer a heart attack. An expert in cardiovascular medicine can predict this without using Machine Learning, but probably not instantly, and certainly not when dealing with hundreds or thousands of samples!
A brief explanation of the dataset variables:
- age: Age of the patient
- sex: Sex of the patient
- cp: Chest pain type ~ 0 = Typical Angina, 1 = Atypical Angina, 2 = Non-anginal Pain, 3 = Asymptomatic
- trtbps: Resting blood pressure (in mm Hg)
- chol: Cholesterol in mg/dl fetched via BMI sensor
- fbs: (fasting blood sugar > 120 mg/dl) ~ 1 = True, 0 = False
- restecg: Resting electrocardiographic results ~ 0 = Normal, 1 = ST-T wave abnormality, 2 = Left ventricular hypertrophy
- thalachh: Maximum heart rate achieved
- oldpeak: ST depression induced by exercise relative to rest
- slp: Slope of the peak exercise ST segment
- caa: Number of major vessels colored by fluoroscopy
- thall: Thalium Stress Test result ~ (0,3)
- exng: Exercise induced angina ~ 1 = Yes, 0 = No
- output: Target variable ~ 1 = higher chance of heart attack, 0 = lower chance
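For readability during EDA, the integer codes above can be mapped to text labels; a minimal sketch on toy rows (the `sex` encoding of 1 = Male / 0 = Female is an assumption to verify against the data source):

```python
import pandas as pd

# Toy sample using the same integer codes as the dataset
sample = pd.DataFrame({"cp": [0, 1, 2, 3], "sex": [1, 0, 1, 0]})

# Label dictionaries built from the variable descriptions above
cp_labels = {0: "Typical Angina", 1: "Atypical Angina",
             2: "Non-anginal Pain", 3: "Asymptomatic"}
sex_labels = {0: "Female", 1: "Male"}  # assumed encoding

sample["cp_label"] = sample["cp"].map(cp_labels)
sample["sex_label"] = sample["sex"].map(sex_labels)
print(sample)
```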
# Heart Attack.csv file - Which factors can increase or decrease the probability of a heart attack?
import pandas as pd
df = pd.read_csv('Heart Attack.csv')
df
| age | sex | cp | trtbps | chol | fbs | restecg | thalachh | exng | oldpeak | slp | caa | thall | output | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 63 | 1 | 3 | 145 | 233 | 1 | 0 | 150 | 0 | 2.3 | 0 | 0 | 1 | 1 |
| 1 | 37 | 1 | 2 | 130 | 250 | 0 | 1 | 187 | 0 | 3.5 | 0 | 0 | 2 | 1 |
| 2 | 41 | 0 | 1 | 130 | 204 | 0 | 0 | 172 | 0 | 1.4 | 2 | 0 | 2 | 1 |
| 3 | 56 | 1 | 1 | 120 | 236 | 0 | 1 | 178 | 0 | 0.8 | 2 | 0 | 2 | 1 |
| 4 | 57 | 0 | 0 | 120 | 354 | 0 | 1 | 163 | 1 | 0.6 | 2 | 0 | 2 | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 298 | 57 | 0 | 0 | 140 | 241 | 0 | 1 | 123 | 1 | 0.2 | 1 | 0 | 3 | 0 |
| 299 | 45 | 1 | 3 | 110 | 264 | 0 | 1 | 132 | 0 | 1.2 | 1 | 0 | 3 | 0 |
| 300 | 68 | 1 | 0 | 144 | 193 | 1 | 1 | 141 | 0 | 3.4 | 1 | 2 | 3 | 0 |
| 301 | 57 | 1 | 0 | 130 | 131 | 0 | 1 | 115 | 1 | 1.2 | 1 | 1 | 3 | 0 |
| 302 | 57 | 0 | 1 | 130 | 236 | 0 | 0 | 174 | 0 | 0.0 | 1 | 1 | 2 | 0 |
303 rows × 14 columns
import sweetviz as sv
reporte = sv.analyze(df)
reporte
<sweetviz.dataframe_report.DataframeReport at 0x26330384820>
# Doing EDA (Exploratory Data Analysis) tends to be somewhat laborious depending on how detailed you want it to be, but try the following library; from now on your EDA may be easier ;)
from dataprep.eda import create_report
create_report(df)
Dataset Statistics
| Number of Variables | 14 |
|---|---|
| Number of Rows | 303 |
| Missing Cells | 0 |
| Missing Cells (%) | 0.0% |
| Duplicate Rows | 1 |
| Duplicate Rows (%) | 0.3% |
| Total Size in Memory | 33.3 KB |
| Average Row Size in Memory | 112.4 B |
Dataset Insights
| trtbps is skewed | Skewed |
|---|---|
| oldpeak is skewed | Skewed |
| sex has constant length 1 | Constant Length |
| cp has constant length 1 | Constant Length |
| fbs has constant length 1 | Constant Length |
| restecg has constant length 1 | Constant Length |
| exng has constant length 1 | Constant Length |
| slp has constant length 1 | Constant Length |
| caa has constant length 1 | Constant Length |
| thall has constant length 1 | Constant Length |
| output has constant length 1 | Constant Length |
| oldpeak has 99 (32.67%) zeros | Zeros |
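The report flags one duplicate row (0.3%); the same check can be run directly in pandas, shown here on a toy frame (on the real data it would be `df.duplicated().sum()`):

```python
import pandas as pd

toy = pd.DataFrame({"a": [1, 2, 2, 3], "b": [10, 20, 20, 30]})

# Count fully duplicated rows (the first occurrence is not counted)
n_dups = toy.duplicated().sum()
print(n_dups)        # 1

# Drop them if that is appropriate for the analysis
deduped = toy.drop_duplicates()
print(len(deduped))  # 3
```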
age
numerical
| Approximate Distinct Count | 41 |
|---|---|
| Approximate Unique (%) | 13.5% |
| Missing | 0 |
| Missing (%) | 0.0% |
| Infinite | 0 |
| Infinite (%) | 0.0% |
| Memory Size | 4848 |
| Mean | 54.3663 |
| Minimum | 29 |
| Maximum | 77 |
| Zeros | 0 |
| Zeros (%) | 0.0% |
| Negatives | 0 |
| Negatives (%) | 0.0% |
- age is skewed left (γ1 = -0.2015)
Quantile Statistics
| Minimum | 29 |
|---|---|
| 5-th Percentile | 39.1 |
| Q1 | 47.5 |
| Median | 55 |
| Q3 | 61 |
| 95-th Percentile | 68 |
| Maximum | 77 |
| Range | 48 |
| IQR | 13.5 |
Descriptive Statistics
| Mean | 54.3663 |
|---|---|
| Standard Deviation | 9.0821 |
| Variance | 82.4846 |
| Sum | 16473 |
| Skewness | -0.2015 |
| Kurtosis | -0.553 |
| Coefficient of Variation | 0.1671 |
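These quantile and descriptive statistics can be reproduced with plain pandas; a minimal sketch on toy ages (note that pandas' `.skew()` returns the sample skewness, which may differ slightly from the report's γ1):

```python
import pandas as pd

ages = pd.Series([29, 40, 50, 55, 60, 70, 77])

print(ages.median())                              # the Median row
print(ages.quantile(0.25))                        # Q1
print(ages.quantile(0.75) - ages.quantile(0.25))  # IQR
print(ages.skew())                                # sample skewness
```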
sex
categorical
| Approximate Distinct Count | 2 |
|---|---|
| Approximate Unique (%) | 0.7% |
| Missing | 0 |
| Missing (%) | 0.0% |
| Memory Size | 19998 |
- The largest value (1) is over 2.16 times larger than the second largest value (0)
Length
| Mean | 1 |
|---|---|
| Standard Deviation | 0 |
| Median | 1 |
| Minimum | 1 |
| Maximum | 1 |
Sample
| 1st row | 1 |
|---|---|
| 2nd row | 1 |
| 3rd row | 0 |
| 4th row | 1 |
| 5th row | 0 |
Letter
| Count | 0 |
|---|---|
| Lowercase Letter | 0 |
| Space Separator | 0 |
| Uppercase Letter | 0 |
| Dash Punctuation | 0 |
| Decimal Number | 303 |
- The top 2 categories (1, 0) take over 50.0%
- The largest value (1) is over 2.16 times larger than the second largest value (0)
- sex has words of constant length
cp
categorical
| Approximate Distinct Count | 4 |
|---|---|
| Approximate Unique (%) | 1.3% |
| Missing | 0 |
| Missing (%) | 0.0% |
| Memory Size | 19998 |
- The largest value (0) is over 1.64 times larger than the second largest value (2)
Length
| Mean | 1 |
|---|---|
| Standard Deviation | 0 |
| Median | 1 |
| Minimum | 1 |
| Maximum | 1 |
Sample
| 1st row | 3 |
|---|---|
| 2nd row | 2 |
| 3rd row | 1 |
| 4th row | 1 |
| 5th row | 0 |
Letter
| Count | 0 |
|---|---|
| Lowercase Letter | 0 |
| Space Separator | 0 |
| Uppercase Letter | 0 |
| Dash Punctuation | 0 |
| Decimal Number | 303 |
- The top 2 categories (0, 2) take over 50.0%
- The largest value (0) is over 1.64 times larger than the second largest value (2)
- cp has words of constant length
trtbps
numerical
| Approximate Distinct Count | 49 |
|---|---|
| Approximate Unique (%) | 16.2% |
| Missing | 0 |
| Missing (%) | 0.0% |
| Infinite | 0 |
| Infinite (%) | 0.0% |
| Memory Size | 4848 |
| Mean | 131.6238 |
| Minimum | 94 |
| Maximum | 200 |
| Zeros | 0 |
| Zeros (%) | 0.0% |
| Negatives | 0 |
| Negatives (%) | 0.0% |
- trtbps is skewed right (γ1 = 0.7102)
Quantile Statistics
| Minimum | 94 |
|---|---|
| 5-th Percentile | 108 |
| Q1 | 120 |
| Median | 130 |
| Q3 | 140 |
| 95-th Percentile | 160 |
| Maximum | 200 |
| Range | 106 |
| IQR | 20 |
Descriptive Statistics
| Mean | 131.6238 |
|---|---|
| Standard Deviation | 17.5381 |
| Variance | 307.5865 |
| Sum | 39882 |
| Skewness | 0.7102 |
| Kurtosis | 0.8941 |
| Coefficient of Variation | 0.1332 |
- trtbps is not normally distributed (p-value 1.9245365572085903e-11)
- trtbps has 9 outliers
chol
numerical
| Approximate Distinct Count | 152 |
|---|---|
| Approximate Unique (%) | 50.2% |
| Missing | 0 |
| Missing (%) | 0.0% |
| Infinite | 0 |
| Infinite (%) | 0.0% |
| Memory Size | 4848 |
| Mean | 246.264 |
| Minimum | 126 |
| Maximum | 564 |
| Zeros | 0 |
| Zeros (%) | 0.0% |
| Negatives | 0 |
| Negatives (%) | 0.0% |
- chol is skewed right (γ1 = 1.1377)
Quantile Statistics
| Minimum | 126 |
|---|---|
| 5-th Percentile | 175 |
| Q1 | 211 |
| Median | 240 |
| Q3 | 274.5 |
| 95-th Percentile | 326.9 |
| Maximum | 564 |
| Range | 438 |
| IQR | 63.5 |
Descriptive Statistics
| Mean | 246.264 |
|---|---|
| Standard Deviation | 51.8308 |
| Variance | 2686.4267 |
| Sum | 74618 |
| Skewness | 1.1377 |
| Kurtosis | 4.4117 |
| Coefficient of Variation | 0.2105 |
- chol is not normally distributed (p-value 0.009062913506085566)
- chol has 5 outliers
fbs
categorical
| Approximate Distinct Count | 2 |
|---|---|
| Approximate Unique (%) | 0.7% |
| Missing | 0 |
| Missing (%) | 0.0% |
| Memory Size | 19998 |
- The largest value (0) is over 5.73 times larger than the second largest value (1)
Length
| Mean | 1 |
|---|---|
| Standard Deviation | 0 |
| Median | 1 |
| Minimum | 1 |
| Maximum | 1 |
Sample
| 1st row | 1 |
|---|---|
| 2nd row | 0 |
| 3rd row | 0 |
| 4th row | 0 |
| 5th row | 0 |
Letter
| Count | 0 |
|---|---|
| Lowercase Letter | 0 |
| Space Separator | 0 |
| Uppercase Letter | 0 |
| Dash Punctuation | 0 |
| Decimal Number | 303 |
- The top 2 categories (0, 1) take over 50.0%
- The largest value (0) is over 5.73 times larger than the second largest value (1)
- fbs has words of constant length
restecg
categorical
| Approximate Distinct Count | 3 |
|---|---|
| Approximate Unique (%) | 1.0% |
| Missing | 0 |
| Missing (%) | 0.0% |
| Memory Size | 19998 |
Length
| Mean | 1 |
|---|---|
| Standard Deviation | 0 |
| Median | 1 |
| Minimum | 1 |
| Maximum | 1 |
Sample
| 1st row | 0 |
|---|---|
| 2nd row | 1 |
| 3rd row | 0 |
| 4th row | 1 |
| 5th row | 1 |
Letter
| Count | 0 |
|---|---|
| Lowercase Letter | 0 |
| Space Separator | 0 |
| Uppercase Letter | 0 |
| Dash Punctuation | 0 |
| Decimal Number | 303 |
- The top 2 categories (1, 0) take over 50.0%
- restecg has words of constant length
thalachh
numerical
| Approximate Distinct Count | 91 |
|---|---|
| Approximate Unique (%) | 30.0% |
| Missing | 0 |
| Missing (%) | 0.0% |
| Infinite | 0 |
| Infinite (%) | 0.0% |
| Memory Size | 4848 |
| Mean | 149.6469 |
| Minimum | 71 |
| Maximum | 202 |
| Zeros | 0 |
| Zeros (%) | 0.0% |
| Negatives | 0 |
| Negatives (%) | 0.0% |
- thalachh is skewed left (γ1 = -0.5347)
Quantile Statistics
| Minimum | 71 |
|---|---|
| 5-th Percentile | 108.1 |
| Q1 | 133.5 |
| Median | 153 |
| Q3 | 166 |
| 95-th Percentile | 181.9 |
| Maximum | 202 |
| Range | 131 |
| IQR | 32.5 |
Descriptive Statistics
| Mean | 149.6469 |
|---|---|
| Standard Deviation | 22.9052 |
| Variance | 524.6464 |
| Sum | 45343 |
| Skewness | -0.5347 |
| Kurtosis | -0.08069 |
| Coefficient of Variation | 0.1531 |
- thalachh has 1 outliers
exng
categorical
| Approximate Distinct Count | 2 |
|---|---|
| Approximate Unique (%) | 0.7% |
| Missing | 0 |
| Missing (%) | 0.0% |
| Memory Size | 19998 |
- The largest value (0) is over 2.06 times larger than the second largest value (1)
Length
| Mean | 1 |
|---|---|
| Standard Deviation | 0 |
| Median | 1 |
| Minimum | 1 |
| Maximum | 1 |
Sample
| 1st row | 0 |
|---|---|
| 2nd row | 0 |
| 3rd row | 0 |
| 4th row | 0 |
| 5th row | 1 |
Letter
| Count | 0 |
|---|---|
| Lowercase Letter | 0 |
| Space Separator | 0 |
| Uppercase Letter | 0 |
| Dash Punctuation | 0 |
| Decimal Number | 303 |
- The top 2 categories (0, 1) take over 50.0%
- The largest value (0) is over 2.06 times larger than the second largest value (1)
- exng has words of constant length
oldpeak
numerical
| Approximate Distinct Count | 40 |
|---|---|
| Approximate Unique (%) | 13.2% |
| Missing | 0 |
| Missing (%) | 0.0% |
| Infinite | 0 |
| Infinite (%) | 0.0% |
| Memory Size | 4848 |
| Mean | 1.0396 |
| Minimum | 0 |
| Maximum | 6.2 |
| Zeros | 99 |
| Zeros (%) | 32.7% |
| Negatives | 0 |
| Negatives (%) | 0.0% |
- oldpeak is skewed right (γ1 = 1.2634)
Quantile Statistics
| Minimum | 0 |
|---|---|
| 5-th Percentile | 0 |
| Q1 | 0 |
| Median | 0.8 |
| Q3 | 1.6 |
| 95-th Percentile | 3.4 |
| Maximum | 6.2 |
| Range | 6.2 |
| IQR | 1.6 |
Descriptive Statistics
| Mean | 1.0396 |
|---|---|
| Standard Deviation | 1.1611 |
| Variance | 1.3481 |
| Sum | 315 |
| Skewness | 1.2634 |
| Kurtosis | 1.5302 |
| Coefficient of Variation | 1.1168 |
- oldpeak is not normally distributed (p-value 3.2207216050407436e-22)
- oldpeak has 5 outliers
slp
categorical
| Approximate Distinct Count | 3 |
|---|---|
| Approximate Unique (%) | 1.0% |
| Missing | 0 |
| Missing (%) | 0.0% |
| Memory Size | 19998 |
Length
| Mean | 1 |
|---|---|
| Standard Deviation | 0 |
| Median | 1 |
| Minimum | 1 |
| Maximum | 1 |
Sample
| 1st row | 0 |
|---|---|
| 2nd row | 0 |
| 3rd row | 2 |
| 4th row | 2 |
| 5th row | 2 |
Letter
| Count | 0 |
|---|---|
| Lowercase Letter | 0 |
| Space Separator | 0 |
| Uppercase Letter | 0 |
| Dash Punctuation | 0 |
| Decimal Number | 303 |
- The top 2 categories (2, 1) take over 50.0%
- slp has words of constant length
caa
categorical
| Approximate Distinct Count | 5 |
|---|---|
| Approximate Unique (%) | 1.7% |
| Missing | 0 |
| Missing (%) | 0.0% |
| Memory Size | 19998 |
- The largest value (0) is over 2.69 times larger than the second largest value (1)
Length
| Mean | 1 |
|---|---|
| Standard Deviation | 0 |
| Median | 1 |
| Minimum | 1 |
| Maximum | 1 |
Sample
| 1st row | 0 |
|---|---|
| 2nd row | 0 |
| 3rd row | 0 |
| 4th row | 0 |
| 5th row | 0 |
Letter
| Count | 0 |
|---|---|
| Lowercase Letter | 0 |
| Space Separator | 0 |
| Uppercase Letter | 0 |
| Dash Punctuation | 0 |
| Decimal Number | 303 |
- The top 2 categories (0, 1) take over 50.0%
- The largest value (0) is over 2.69 times larger than the second largest value (1)
- caa has words of constant length
thall
categorical
| Approximate Distinct Count | 4 |
|---|---|
| Approximate Unique (%) | 1.3% |
| Missing | 0 |
| Missing (%) | 0.0% |
| Memory Size | 19998 |
Length
| Mean | 1 |
|---|---|
| Standard Deviation | 0 |
| Median | 1 |
| Minimum | 1 |
| Maximum | 1 |
Sample
| 1st row | 1 |
|---|---|
| 2nd row | 2 |
| 3rd row | 2 |
| 4th row | 2 |
| 5th row | 2 |
Letter
| Count | 0 |
|---|---|
| Lowercase Letter | 0 |
| Space Separator | 0 |
| Uppercase Letter | 0 |
| Dash Punctuation | 0 |
| Decimal Number | 303 |
- The top 2 categories (2, 3) take over 50.0%
- thall has words of constant length
output
categorical
| Approximate Distinct Count | 2 |
|---|---|
| Approximate Unique (%) | 0.7% |
| Missing | 0 |
| Missing (%) | 0.0% |
| Memory Size | 19998 |
Length
| Mean | 1 |
|---|---|
| Standard Deviation | 0 |
| Median | 1 |
| Minimum | 1 |
| Maximum | 1 |
Sample
| 1st row | 1 |
|---|---|
| 2nd row | 1 |
| 3rd row | 1 |
| 4th row | 1 |
| 5th row | 1 |
Letter
| Count | 0 |
|---|---|
| Lowercase Letter | 0 |
| Space Separator | 0 |
| Uppercase Letter | 0 |
| Dash Punctuation | 0 |
| Decimal Number | 303 |
- The top 2 categories (1, 0) take over 50.0%
- output has words of constant length
k-Nearest Neighbors¶
Having done an Exploratory Data Analysis of the factors that may increase or decrease the chance of a heart attack, it is time to build your first classifier using the k-NN algorithm!
Note: it is important to make sure the data are in the format required by the scikit-learn library. The features must be arranged in a matrix where each column is a variable and each row is a different observation, in this case the clinical record of one patient. The target variable must be a single column with the same number of observations.
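The required layout can be sketched with placeholder arrays (the 303 x 13 shape matches this dataset, but any row-aligned matrix and vector work the same way):

```python
import numpy as np

# Feature matrix: one row per observation (patient), one column per variable
X = np.zeros((303, 13))
# Target: one label per observation, as a single 1-D column
y = np.zeros(303)

# scikit-learn estimators expect these row counts to match
assert X.shape[0] == y.shape[0]
print(X.shape, y.shape)
```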
# Import the k-NN classifier from sklearn
from sklearn.neighbors import KNeighborsClassifier
# Create two arrays holding the independent variables and the target variable
y = df.output.values
x_data = df.drop(['output'], axis=1)
# Target variable
y
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int64)
# Independent variables
x_data
| age | sex | cp | trtbps | chol | fbs | restecg | thalachh | exng | oldpeak | slp | caa | thall | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 63 | 1 | 3 | 145 | 233 | 1 | 0 | 150 | 0 | 2.3 | 0 | 0 | 1 |
| 1 | 37 | 1 | 2 | 130 | 250 | 0 | 1 | 187 | 0 | 3.5 | 0 | 0 | 2 |
| 2 | 41 | 0 | 1 | 130 | 204 | 0 | 0 | 172 | 0 | 1.4 | 2 | 0 | 2 |
| 3 | 56 | 1 | 1 | 120 | 236 | 0 | 1 | 178 | 0 | 0.8 | 2 | 0 | 2 |
| 4 | 57 | 0 | 0 | 120 | 354 | 0 | 1 | 163 | 1 | 0.6 | 2 | 0 | 2 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 298 | 57 | 0 | 0 | 140 | 241 | 0 | 1 | 123 | 1 | 0.2 | 1 | 0 | 3 |
| 299 | 45 | 1 | 3 | 110 | 264 | 0 | 1 | 132 | 0 | 1.2 | 1 | 0 | 3 |
| 300 | 68 | 1 | 0 | 144 | 193 | 1 | 1 | 141 | 0 | 3.4 | 1 | 2 | 3 |
| 301 | 57 | 1 | 0 | 130 | 131 | 0 | 1 | 115 | 1 | 1.2 | 1 | 1 | 3 |
| 302 | 57 | 0 | 1 | 130 | 236 | 0 | 0 | 174 | 0 | 0.0 | 1 | 1 | 2 |
303 rows × 13 columns
# Normalization: min-max scale each column to [0, 1]
import numpy as np
x = (x_data - x_data.min()) / (x_data.max() - x_data.min())
x
| age | sex | cp | trtbps | chol | fbs | restecg | thalachh | exng | oldpeak | slp | caa | thall | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.708333 | 1.0 | 1.000000 | 0.481132 | 0.244292 | 1.0 | 0.0 | 0.603053 | 0.0 | 0.370968 | 0.0 | 0.00 | 0.333333 |
| 1 | 0.166667 | 1.0 | 0.666667 | 0.339623 | 0.283105 | 0.0 | 0.5 | 0.885496 | 0.0 | 0.564516 | 0.0 | 0.00 | 0.666667 |
| 2 | 0.250000 | 0.0 | 0.333333 | 0.339623 | 0.178082 | 0.0 | 0.0 | 0.770992 | 0.0 | 0.225806 | 1.0 | 0.00 | 0.666667 |
| 3 | 0.562500 | 1.0 | 0.333333 | 0.245283 | 0.251142 | 0.0 | 0.5 | 0.816794 | 0.0 | 0.129032 | 1.0 | 0.00 | 0.666667 |
| 4 | 0.583333 | 0.0 | 0.000000 | 0.245283 | 0.520548 | 0.0 | 0.5 | 0.702290 | 1.0 | 0.096774 | 1.0 | 0.00 | 0.666667 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 298 | 0.583333 | 0.0 | 0.000000 | 0.433962 | 0.262557 | 0.0 | 0.5 | 0.396947 | 1.0 | 0.032258 | 0.5 | 0.00 | 1.000000 |
| 299 | 0.333333 | 1.0 | 1.000000 | 0.150943 | 0.315068 | 0.0 | 0.5 | 0.465649 | 0.0 | 0.193548 | 0.5 | 0.00 | 1.000000 |
| 300 | 0.812500 | 1.0 | 0.000000 | 0.471698 | 0.152968 | 1.0 | 0.5 | 0.534351 | 0.0 | 0.548387 | 0.5 | 0.50 | 1.000000 |
| 301 | 0.583333 | 1.0 | 0.000000 | 0.339623 | 0.011416 | 0.0 | 0.5 | 0.335878 | 1.0 | 0.193548 | 0.5 | 0.25 | 1.000000 |
| 302 | 0.583333 | 0.0 | 0.333333 | 0.339623 | 0.251142 | 0.0 | 0.0 | 0.786260 | 0.0 | 0.000000 | 0.5 | 0.25 | 0.666667 |
303 rows × 13 columns
# Training and test sets
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x,y, test_size = 0.30, random_state = 1)
x_train = np.array(x_train)
x_test = np.array(x_test)
# Create a k-NN classifier with 6 neighbors
knn = KNeighborsClassifier(n_neighbors=6)
# Fit the classifier to the variables
knn.fit(x_train, y_train)
KNeighborsClassifier(n_neighbors=6)
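After fitting, the held-out test set can be used to gauge accuracy; a self-contained sketch on synthetic data (on the real split this would be `knn.score(x_test, y_test)`):

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Synthetic, roughly separable binary problem
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = (X[:, 0] + X[:, 1] > 0).astype(int)

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=1)
knn = KNeighborsClassifier(n_neighbors=6).fit(X_tr, y_tr)

# Mean accuracy on data the model has not seen
print(knn.score(X_te, y_te))
```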
Prediction¶
Once the k-NN classifier is trained, we can use it to predict a new record. In this case there are no unclassified data available, since all of them were used to train the model. To compute a prediction with the .predict() method, we will therefore simulate a completely new observation.
# Create an array simulating a new observation.
# Note: the model was trained on min-max-scaled data, so in practice this
# observation should be scaled with the same training minima/maxima first.
X_new = [41, 0, 2, 145, 233, 1, 0, 150, 0, 3.5, 1, 0, 2]
X_new = np.array(X_new).reshape(1, -1)
# Predict the class for the array you created
y_new_pred = knn.predict(X_new)
print("Prediction: {}".format(y_new_pred))
Prediction: [0]
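Beyond a hard label, k-NN can also report the fraction of neighbors that voted for each class through `predict_proba`; a self-contained sketch on toy one-dimensional data:

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

X = np.array([[0.0], [0.1], [0.2], [1.0], [1.1], [1.2]])
y = np.array([0, 0, 0, 1, 1, 1])

knn = KNeighborsClassifier(n_neighbors=3).fit(X, y)

# Probabilities are neighbor vote fractions: here all 3 nearest
# points to 0.05 belong to class 0
print(knn.predict_proba([[0.05]]))  # [[1. 0.]]
```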
Digit recognition¶
So far we have only performed binary classification, since the target variable had two possible outcomes. In the following exercises you will work with the MNIST digit recognition dataset, which has 10 classes, the digits 0 through 9! A reduced version of the MNIST dataset is one of the datasets included in scikit-learn.
Each sample in this dataset is a 28x28 image representing a handwritten digit, flattened into 784 pixel columns. Each pixel holds an integer from 0 to 255 indicating how dark it is.

# Import the MNIST file
digits = pd.read_csv('MNIST.csv')
digits.head(10)
| label | pixel0 | pixel1 | pixel2 | pixel3 | pixel4 | pixel5 | pixel6 | pixel7 | pixel8 | ... | pixel774 | pixel775 | pixel776 | pixel777 | pixel778 | pixel779 | pixel780 | pixel781 | pixel782 | pixel783 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | 4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 5 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 6 | 7 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 7 | 3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 8 | 5 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 9 | 3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
10 rows × 785 columns
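Since each row stores a flattened 28x28 image, the 784 pixel columns can be folded back into a square array for display; a minimal numpy sketch:

```python
import numpy as np

flat = np.arange(784)       # stand-in for one row of pixel columns
img = flat.reshape(28, 28)  # row-major: pixel0..pixel27 form the first image row

print(img.shape)            # (28, 28)
print(img[0, 27], img[1, 0])  # pixel27, then pixel28 starts the second row
```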
# Create a variable 'cols' referencing every column whose name contains the word 'pixel'
cols = [col for col in digits.columns if 'pixel' in col]
cols
['pixel0', 'pixel1', 'pixel2', 'pixel3', 'pixel4', 'pixel5', 'pixel6', 'pixel7', 'pixel8', 'pixel9', 'pixel10', 'pixel11', 'pixel12', 'pixel13', 'pixel14', 'pixel15', 'pixel16', 'pixel17', 'pixel18', 'pixel19', 'pixel20', 'pixel21', 'pixel22', 'pixel23', 'pixel24', 'pixel25', 'pixel26', 'pixel27', 'pixel28', 'pixel29', 'pixel30', 'pixel31', 'pixel32', 'pixel33', 'pixel34', 'pixel35', 'pixel36', 'pixel37', 'pixel38', 'pixel39', 'pixel40', 'pixel41', 'pixel42', 'pixel43', 'pixel44', 'pixel45', 'pixel46', 'pixel47', 'pixel48', 'pixel49', 'pixel50', 'pixel51', 'pixel52', 'pixel53', 'pixel54', 'pixel55', 'pixel56', 'pixel57', 'pixel58', 'pixel59', 'pixel60', 'pixel61', 'pixel62', 'pixel63', 'pixel64', 'pixel65', 'pixel66', 'pixel67', 'pixel68', 'pixel69', 'pixel70', 'pixel71', 'pixel72', 'pixel73', 'pixel74', 'pixel75', 'pixel76', 'pixel77', 'pixel78', 'pixel79', 'pixel80', 'pixel81', 'pixel82', 'pixel83', 'pixel84', 'pixel85', 'pixel86', 'pixel87', 'pixel88', 'pixel89', 'pixel90', 'pixel91', 'pixel92', 'pixel93', 'pixel94', 'pixel95', 'pixel96', 'pixel97', 'pixel98', 'pixel99', 'pixel100', 'pixel101', 'pixel102', 'pixel103', 'pixel104', 'pixel105', 'pixel106', 'pixel107', 'pixel108', 'pixel109', 'pixel110', 'pixel111', 'pixel112', 'pixel113', 'pixel114', 'pixel115', 'pixel116', 'pixel117', 'pixel118', 'pixel119', 'pixel120', 'pixel121', 'pixel122', 'pixel123', 'pixel124', 'pixel125', 'pixel126', 'pixel127', 'pixel128', 'pixel129', 'pixel130', 'pixel131', 'pixel132', 'pixel133', 'pixel134', 'pixel135', 'pixel136', 'pixel137', 'pixel138', 'pixel139', 'pixel140', 'pixel141', 'pixel142', 'pixel143', 'pixel144', 'pixel145', 'pixel146', 'pixel147', 'pixel148', 'pixel149', 'pixel150', 'pixel151', 'pixel152', 'pixel153', 'pixel154', 'pixel155', 'pixel156', 'pixel157', 'pixel158', 'pixel159', 'pixel160', 'pixel161', 'pixel162', 'pixel163', 'pixel164', 'pixel165', 'pixel166', 'pixel167', 'pixel168', 'pixel169', 'pixel170', 'pixel171', 'pixel172', 'pixel173', 'pixel174', 
'pixel175', 'pixel176', 'pixel177', 'pixel178', 'pixel179', 'pixel180', 'pixel181', 'pixel182', 'pixel183', 'pixel184', 'pixel185', 'pixel186', 'pixel187', 'pixel188', 'pixel189', 'pixel190', 'pixel191', 'pixel192', 'pixel193', 'pixel194', 'pixel195', 'pixel196', 'pixel197', 'pixel198', 'pixel199', 'pixel200', 'pixel201', 'pixel202', 'pixel203', 'pixel204', 'pixel205', 'pixel206', 'pixel207', 'pixel208', 'pixel209', 'pixel210', 'pixel211', 'pixel212', 'pixel213', 'pixel214', 'pixel215', 'pixel216', 'pixel217', 'pixel218', 'pixel219', 'pixel220', 'pixel221', 'pixel222', 'pixel223', 'pixel224', 'pixel225', 'pixel226', 'pixel227', 'pixel228', 'pixel229', 'pixel230', 'pixel231', 'pixel232', 'pixel233', 'pixel234', 'pixel235', 'pixel236', 'pixel237', 'pixel238', 'pixel239', 'pixel240', 'pixel241', 'pixel242', 'pixel243', 'pixel244', 'pixel245', 'pixel246', 'pixel247', 'pixel248', 'pixel249', 'pixel250', 'pixel251', 'pixel252', 'pixel253', 'pixel254', 'pixel255', 'pixel256', 'pixel257', 'pixel258', 'pixel259', 'pixel260', 'pixel261', 'pixel262', 'pixel263', 'pixel264', 'pixel265', 'pixel266', 'pixel267', 'pixel268', 'pixel269', 'pixel270', 'pixel271', 'pixel272', 'pixel273', 'pixel274', 'pixel275', 'pixel276', 'pixel277', 'pixel278', 'pixel279', 'pixel280', 'pixel281', 'pixel282', 'pixel283', 'pixel284', 'pixel285', 'pixel286', 'pixel287', 'pixel288', 'pixel289', 'pixel290', 'pixel291', 'pixel292', 'pixel293', 'pixel294', 'pixel295', 'pixel296', 'pixel297', 'pixel298', 'pixel299', 'pixel300', 'pixel301', 'pixel302', 'pixel303', 'pixel304', 'pixel305', 'pixel306', 'pixel307', 'pixel308', 'pixel309', 'pixel310', 'pixel311', 'pixel312', 'pixel313', 'pixel314', 'pixel315', 'pixel316', 'pixel317', 'pixel318', 'pixel319', 'pixel320', 'pixel321', 'pixel322', 'pixel323', 'pixel324', 'pixel325', 'pixel326', 'pixel327', 'pixel328', 'pixel329', 'pixel330', 'pixel331', 'pixel332', 'pixel333', 'pixel334', 'pixel335', 'pixel336', 'pixel337', 'pixel338', 'pixel339', 'pixel340', 
'pixel341', 'pixel342', 'pixel343', 'pixel344', 'pixel345', 'pixel346', 'pixel347', 'pixel348', 'pixel349', 'pixel350', 'pixel351', 'pixel352', 'pixel353', 'pixel354', 'pixel355', 'pixel356', 'pixel357', 'pixel358', 'pixel359', 'pixel360', 'pixel361', 'pixel362', 'pixel363', 'pixel364', 'pixel365', 'pixel366', 'pixel367', 'pixel368', 'pixel369', 'pixel370', 'pixel371', 'pixel372', 'pixel373', 'pixel374', 'pixel375', 'pixel376', 'pixel377', 'pixel378', 'pixel379', 'pixel380', 'pixel381', 'pixel382', 'pixel383', 'pixel384', 'pixel385', 'pixel386', 'pixel387', 'pixel388', 'pixel389', 'pixel390', 'pixel391', 'pixel392', 'pixel393', 'pixel394', 'pixel395', 'pixel396', 'pixel397', 'pixel398', 'pixel399', 'pixel400', 'pixel401', 'pixel402', 'pixel403', 'pixel404', 'pixel405', 'pixel406', 'pixel407', 'pixel408', 'pixel409', 'pixel410', 'pixel411', 'pixel412', 'pixel413', 'pixel414', 'pixel415', 'pixel416', 'pixel417', 'pixel418', 'pixel419', 'pixel420', 'pixel421', 'pixel422', 'pixel423', 'pixel424', 'pixel425', 'pixel426', 'pixel427', 'pixel428', 'pixel429', 'pixel430', 'pixel431', 'pixel432', 'pixel433', 'pixel434', 'pixel435', 'pixel436', 'pixel437', 'pixel438', 'pixel439', 'pixel440', 'pixel441', 'pixel442', 'pixel443', 'pixel444', 'pixel445', 'pixel446', 'pixel447', 'pixel448', 'pixel449', 'pixel450', 'pixel451', 'pixel452', 'pixel453', 'pixel454', 'pixel455', 'pixel456', 'pixel457', 'pixel458', 'pixel459', 'pixel460', 'pixel461', 'pixel462', 'pixel463', 'pixel464', 'pixel465', 'pixel466', 'pixel467', 'pixel468', 'pixel469', 'pixel470', 'pixel471', 'pixel472', 'pixel473', 'pixel474', 'pixel475', 'pixel476', 'pixel477', 'pixel478', 'pixel479', 'pixel480', 'pixel481', 'pixel482', 'pixel483', 'pixel484', 'pixel485', 'pixel486', 'pixel487', 'pixel488', 'pixel489', 'pixel490', 'pixel491', 'pixel492', 'pixel493', 'pixel494', 'pixel495', 'pixel496', 'pixel497', 'pixel498', 'pixel499', 'pixel500', 'pixel501', 'pixel502', 'pixel503', 'pixel504', 'pixel505', 'pixel506', 
'pixel507', 'pixel508', 'pixel509', 'pixel510', 'pixel511', 'pixel512', 'pixel513', 'pixel514', 'pixel515', 'pixel516', 'pixel517', 'pixel518', 'pixel519', 'pixel520', 'pixel521', 'pixel522', 'pixel523', 'pixel524', 'pixel525', 'pixel526', 'pixel527', 'pixel528', 'pixel529', 'pixel530', 'pixel531', 'pixel532', 'pixel533', 'pixel534', 'pixel535', 'pixel536', 'pixel537', 'pixel538', 'pixel539', 'pixel540', 'pixel541', 'pixel542', 'pixel543', 'pixel544', 'pixel545', 'pixel546', 'pixel547', 'pixel548', 'pixel549', 'pixel550', 'pixel551', 'pixel552', 'pixel553', 'pixel554', 'pixel555', 'pixel556', 'pixel557', 'pixel558', 'pixel559', 'pixel560', 'pixel561', 'pixel562', 'pixel563', 'pixel564', 'pixel565', 'pixel566', 'pixel567', 'pixel568', 'pixel569', 'pixel570', 'pixel571', 'pixel572', 'pixel573', 'pixel574', 'pixel575', 'pixel576', 'pixel577', 'pixel578', 'pixel579', 'pixel580', 'pixel581', 'pixel582', 'pixel583', 'pixel584', 'pixel585', 'pixel586', 'pixel587', 'pixel588', 'pixel589', 'pixel590', 'pixel591', 'pixel592', 'pixel593', 'pixel594', 'pixel595', 'pixel596', 'pixel597', 'pixel598', 'pixel599', 'pixel600', 'pixel601', 'pixel602', 'pixel603', 'pixel604', 'pixel605', 'pixel606', 'pixel607', 'pixel608', 'pixel609', 'pixel610', 'pixel611', 'pixel612', 'pixel613', 'pixel614', 'pixel615', 'pixel616', 'pixel617', 'pixel618', 'pixel619', 'pixel620', 'pixel621', 'pixel622', 'pixel623', 'pixel624', 'pixel625', 'pixel626', 'pixel627', 'pixel628', 'pixel629', 'pixel630', 'pixel631', 'pixel632', 'pixel633', 'pixel634', 'pixel635', 'pixel636', 'pixel637', 'pixel638', 'pixel639', 'pixel640', 'pixel641', 'pixel642', 'pixel643', 'pixel644', 'pixel645', 'pixel646', 'pixel647', 'pixel648', 'pixel649', 'pixel650', 'pixel651', 'pixel652', 'pixel653', 'pixel654', 'pixel655', 'pixel656', 'pixel657', 'pixel658', 'pixel659', 'pixel660', 'pixel661', 'pixel662', 'pixel663', 'pixel664', 'pixel665', 'pixel666', 'pixel667', 'pixel668', 'pixel669', 'pixel670', 'pixel671', 'pixel672', 
'pixel673', 'pixel674', 'pixel675', 'pixel676', 'pixel677', 'pixel678', 'pixel679', 'pixel680', 'pixel681', 'pixel682', 'pixel683', 'pixel684', 'pixel685', 'pixel686', 'pixel687', 'pixel688', 'pixel689', 'pixel690', 'pixel691', 'pixel692', 'pixel693', 'pixel694', 'pixel695', 'pixel696', 'pixel697', 'pixel698', 'pixel699', 'pixel700', 'pixel701', 'pixel702', 'pixel703', 'pixel704', 'pixel705', 'pixel706', 'pixel707', 'pixel708', 'pixel709', 'pixel710', 'pixel711', 'pixel712', 'pixel713', 'pixel714', 'pixel715', 'pixel716', 'pixel717', 'pixel718', 'pixel719', 'pixel720', 'pixel721', 'pixel722', 'pixel723', 'pixel724', 'pixel725', 'pixel726', 'pixel727', 'pixel728', 'pixel729', 'pixel730', 'pixel731', 'pixel732', 'pixel733', 'pixel734', 'pixel735', 'pixel736', 'pixel737', 'pixel738', 'pixel739', 'pixel740', 'pixel741', 'pixel742', 'pixel743', 'pixel744', 'pixel745', 'pixel746', 'pixel747', 'pixel748', 'pixel749', 'pixel750', 'pixel751', 'pixel752', 'pixel753', 'pixel754', 'pixel755', 'pixel756', 'pixel757', 'pixel758', 'pixel759', 'pixel760', 'pixel761', 'pixel762', 'pixel763', 'pixel764', 'pixel765', 'pixel766', 'pixel767', 'pixel768', 'pixel769', 'pixel770', 'pixel771', 'pixel772', 'pixel773', 'pixel774', 'pixel775', 'pixel776', 'pixel777', 'pixel778', 'pixel779', 'pixel780', 'pixel781', 'pixel782', 'pixel783']
# Let's display one of the digits
import matplotlib.pyplot as plt
i = 0
print("The digit is:", digits.loc[i, 'label'])
plt.imshow(digits.loc[i, cols].values.reshape((28,28)).astype(float), cmap=plt.cm.gray_r, interpolation='nearest')
The digit is: 1
<matplotlib.image.AxesImage at 0x26339edc040>
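As a side note on what `reshape((28,28))` does above: each row of the dataset stores the 784 pixels of one image in row-major order, so reshaping simply recovers the grid. A minimal sketch with a synthetic vector instead of real pixel data:

```python
import numpy as np

# A flattened MNIST-style row holds 784 values; reshaping recovers the 28x28 grid.
flat = np.arange(784)
img = flat.reshape(28, 28)

# Row r, column c of the grid sits at index r*28 + c in the flat vector.
print(img.shape)                # (28, 28)
print(img[3, 5] == 3 * 28 + 5)  # True
```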
Train/Test¶
One of the main differences between classical statistics and Machine Learning is the split of the dataset into training and test sets, which lets us measure and quantify accuracy and error on data the model has, in some sense, "not seen". Next we will build our training and test sets with the train_test_split method and measure our model's accuracy. The goal is to predict which digit an image shows! To do so we will fit a k-NN classifier on the training data and then compute its accuracy with the accuracy_score() method on the test data. How do you think accuracy is computed for a classification model? It seems quite hard, but it isn't ;)
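As a hint for the question above: for a classifier, accuracy is simply the fraction of predictions that match the true labels. A minimal sketch with made-up labels, not the notebook's data:

```python
import numpy as np
from sklearn.metrics import accuracy_score

y_true = np.array([3, 1, 4, 1, 5])
y_pred = np.array([3, 1, 4, 2, 5])  # one mistake out of five

# Accuracy = correct predictions / total predictions
manual = np.mean(y_true == y_pred)
print(manual)                          # 0.8
print(accuracy_score(y_true, y_pred))  # 0.8
```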
# Import the train/test split utility and the accuracy metric
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# Create the arrays for the independent variables and the target variable
X = digits.drop(['label'], axis=1).values
y = digits['label'].values
# Split the arrays into training and test sets with a 70/30 ratio
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
import numpy as np
x_train = np.array(x_train)
x_test = np.array(x_test)
# Instantiate a k-NN classifier with 5 neighbors
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=5)
print(x_train.shape)
print(y_train.shape)
# Fit (train) the classifier on the training set
knn.fit(x_train, y_train)
KNeighborsClassifier()
# Compute the predictions on the test set
y_pred = knn.predict(x_test)
y_pred
y_true = y_test
# Check the model's accuracy
from sklearn.metrics import accuracy_score
print(accuracy_score(y_true, y_pred))
0.966031746031746
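Accuracy compresses performance into a single number. If you also want to see which digits get mistaken for which, a confusion matrix helps; here is a quick sketch with toy labels rather than the notebook's predictions:

```python
from sklearn.metrics import confusion_matrix

y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 2, 2, 2, 1, 0]

# Rows are true classes, columns are predicted classes; the diagonal holds the hits.
cm = confusion_matrix(y_true, y_pred)
print(cm)
# [[2 0 0]
#  [0 1 1]
#  [0 0 2]]
```

Running `confusion_matrix(y_true, y_pred)` on the test set above would show a 10x10 matrix, one row per digit.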
Recognizing your own image¶
With everything above in place, we can recognize any digit you draw. Are you ready?
# Let's visualize the image of a number you will create on your computer with Paint. The image must have a black background and the digit painted in white; you will find an example in the repository
import matplotlib.pyplot as plt
image = plt.imread('/Users/Eduardo/Desktop/Numero 4.jpg') # Put the path to the image you created here, in jpg or png format
plt.imshow(image)
<matplotlib.image.AxesImage at 0x26339e6ea30>
# With this library we will resize the image you created to a 28x28-pixel format
from PIL import Image
pil = Image.open('/Users/Eduardo/Desktop/Numero 4.jpg')
image_resize = pil.resize((28, 28))
# Let's turn the resized image into an array holding the pixel information
pixels = np.asarray(image_resize)
# A few adjustments are needed because of the data format the model was trained on and the layout sklearn expects
arr = pixels.transpose(2, 0, 1).reshape(-1, pixels.shape[1])[0:28]  # keep a single 28x28 channel
image_final = arr.ravel().reshape(1, -1)  # flatten to shape (1, 784) for predict()
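The transpose/reshape above keeps a single channel of the RGB array. An alternative worth knowing (a sketch, not the route this notebook takes) is to let PIL convert the image to grayscale before resizing, which yields the (1, 784) vector directly. A synthetic white-on-black patch stands in for your painted digit:

```python
import numpy as np
from PIL import Image

# Synthetic stand-in for the painted digit: a white patch on a black background
rgb = Image.new('RGB', (200, 200), 'black')
rgb.paste((255, 255, 255), (80, 80, 120, 120))

gray = rgb.convert('L').resize((28, 28))  # one channel, 28x28
flat = np.asarray(gray).reshape(1, -1)    # shape (1, 784), ready for knn.predict
print(flat.shape)  # (1, 784)
```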
# Compute the model's prediction for the number you created. Was the classification correct? :O
print("The digit is:", knn.predict(image_final))
plt.imshow(arr, cmap=plt.cm.gray_r, interpolation='nearest')
The digit is: [1]
<matplotlib.image.AxesImage at 0x2633a02bbe0>
Overfit and Underfit¶
What is the ideal value for the parameter k? We will compute accuracy on the training and test sets over a range of k values. By looking at how these values differ we can pick the best parameter without falling into overfitting or underfitting.
# Set up the initial arrays
neighbors = np.arange(1, 9)
train_accuracy = np.empty(len(neighbors))
test_accuracy = np.empty(len(neighbors))
# Loop over the different values of k
for i, k in enumerate(neighbors):
    # k-NN classifier for this value of k
    knn = KNeighborsClassifier(n_neighbors=k)
    # Fit the classifier to the training set
    knn.fit(x_train, y_train)
    # Accuracy on the training set
    train_accuracy[i] = knn.score(x_train, y_train)
    # Accuracy on the test set
    test_accuracy[i] = knn.score(x_test, y_test)
# Plot to find an optimal value of k
plt.plot(neighbors, test_accuracy, label = 'Testing Accuracy')
plt.plot(neighbors, train_accuracy, label = 'Training Accuracy')
plt.title('k-NN: Accuracy by Number of Neighbors')
plt.xlabel('Number of Neighbors')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
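Once both curves are available, the k that maximizes test accuracy can also be read off programmatically with argmax. A sketch with hypothetical accuracies standing in for the values computed above:

```python
import numpy as np

neighbors = np.arange(1, 9)
# Hypothetical test accuracies, in place of the values computed above
test_accuracy = np.array([0.93, 0.95, 0.96, 0.965, 0.964, 0.963, 0.962, 0.960])

# argmax returns the index of the highest accuracy; index into neighbors to get k
best_k = neighbors[np.argmax(test_accuracy)]
print("Best k:", best_k)  # Best k: 4
```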
Logistic Regression¶
Predict your image again, but this time with a Logistic Regression. Which of the two models gives you better results?
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# Split the data into training (80%) and test (20%) sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Train the model
logreg = LogisticRegression(max_iter=1000)
logreg.fit(X_train, y_train)
C:\Users\Eduardo\anaconda3\envs\rfm_project\lib\site-packages\sklearn\linear_model\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.
Increase the number of iterations (max_iter) or scale the data as shown in:
https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
n_iter_i = _check_optimize_result(
LogisticRegression(max_iter=1000)
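The ConvergenceWarning above points at two fixes: more iterations or feature scaling. A sketch of the scaling route on synthetic data (a StandardScaler inside a Pipeline, which is not part of this notebook's original flow):

```python
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
# Three features on wildly different scales, with a linear decision rule
X_demo = rng.normal(size=(200, 3)) * np.array([1.0, 100.0, 10000.0])
y_demo = (X_demo[:, 0] + X_demo[:, 1] / 100 > 0).astype(int)

# The scaler standardizes each column before the solver ever sees it,
# which usually lets lbfgs converge without raising max_iter
pipe = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
pipe.fit(X_demo, y_demo)
print(pipe.score(X_demo, y_demo))
```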
# Make predictions
y_pred = logreg.predict(X_test)
y_pred
array([8, 1, 9, ..., 3, 0, 9], dtype=int64)
# Evaluate the model
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')
Accuracy: 0.9001190476190476
# Predict your own image with the logistic regression model
y_pred2 = logreg.predict(image_final)
y_pred2
array([5], dtype=int64)
